# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Evaluator for testing.
"""
from inspect import signature
from tqdm.autonotebook import tqdm
from mindnlp import ms_jit
from mindnlp._legacy.abc import Metric
from .callbacks.callback_manager import CallbackManager, RunContext
from ...utils import ModelOutput, logging

logger = logging.get_logger(__name__)

class Evaluator:
    r"""
    Evaluator to test the model.

    Args:
        network (Cell): The network to evaluate.
        eval_dataset (Dataset): An evaluation dataset iterator.
        metrics (Optional[list[Metric], Metric]): Metric objects to compute
            while evaluating. Default: None.
        callbacks (Optional[list[Callback], Callback]): Callback objects to execute
            while evaluating. Default: None.
        jit (bool): Whether to use Just-In-Time compilation. Default: False.
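
    Example:
        A minimal usage sketch; ``model``, ``eval_ds`` and ``Accuracy`` are
        placeholders supplied by the caller, not names defined in this module::

            >>> evaluator = Evaluator(network=model, eval_dataset=eval_ds,
            ...                       metrics=[Accuracy()], jit=False)
            >>> evaluator.run(tgt_columns='label')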
    """
    def __init__(self, network, eval_dataset=None, metrics=None, callbacks=None, jit=False):
        r"""
        Initializes an instance of the Evaluator class.
        
        Args:
            self (Evaluator): The instance of the Evaluator class.
            network (object): The neural network model to be evaluated.
            eval_dataset (object, optional): The dataset used for evaluation. Default is None.
            metrics (list, optional): List of evaluation metrics to be used.
            callbacks (list, optional): List of callback functions to be executed during evaluation.
            jit (bool): Flag indicating whether just-in-time (JIT) compilation should be used.
        
        Returns:
            None. This method does not return any value.
        
        Raises:
            TypeError: If the metrics provided are not in the correct format.
            ValueError: If the evaluation dataset size is not valid.
            Exception: If any other unexpected error occurs during initialization.
        """
        self.network = network
        self.callbacks = callbacks
        self.earlystop = False

        self._check_metric_type(metrics)
        self.eval_dataset = eval_dataset
        self.total = eval_dataset.get_dataset_size()
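        # Dataset size (in batches) drives the tqdm progress bar in _run().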

        self.callback_manager = CallbackManager(callbacks=self.callbacks)
        self.eval_func = self._prepare_eval_func(network, jit)

    def _prepare_eval_func(self, network, jit):
        r"""
        Prepares the evaluation function for the Evaluator class.
        
        Args:
            self (Evaluator): The instance of the Evaluator class.
            network: The network object used for evaluation.
            jit (bool): A flag indicating whether to use just-in-time (JIT) compilation.
        
        Returns:
            None. This method does not return any value.
        
        Raises:
            None.
        
        This method prepares the evaluation function for the Evaluator class. It creates a core process function named '_run_step' 
        that takes any number of positional and keyword arguments. The function calls the 'network' object with the provided 
        arguments and stores the result in 'outputs'. If the 'outputs' is an instance of the 'ModelOutput' class, the method 
        returns the 'logits' attribute of the 'outputs'. Otherwise, it returns the 'outputs'. 
        
        If the 'jit' flag is set to True, the '_run_step' function is compiled using just-in-time (JIT) compilation by calling 
        'ms_jit' with the '_run_step' function as an argument. The compiled function is then returned. If the 'jit' flag is 
        False, the uncompiled '_run_step' function is returned.
        
        Note: The '_run_step' function is an internal function and is not intended to be called directly by external code.
        
        Example usage:
            evaluator = Evaluator()
            network = MyNetwork()
            jit = True
            evaluator._prepare_eval_func(network, jit)
        """
        def _run_step(*args, **kwargs):
            """Core process of each step."""
            outputs = network(*args, **kwargs)
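            # ModelOutput containers are unwrapped to their logits so that
            # metrics always receive raw prediction tensors.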
            if isinstance(outputs, ModelOutput):
                return outputs.logits
            return outputs
        if jit:
            return ms_jit(_run_step)
        return _run_step

    def _check_metric_type(self, metrics):
        """Check metrics type."""
        self.metrics = []
        if not metrics:
            raise ValueError("For Evaluator, the model argument 'metrics' can not be None or empty, "
                             "you should set the argument 'metrics' for model.")
        if isinstance(metrics, Metric):
            self.metrics.append(metrics)
        elif isinstance(metrics, list):
            if all(isinstance(mc, Metric) for mc in metrics):
                self.metrics = metrics
            else:
                obj = next(mc for mc in metrics if not isinstance(mc, Metric))
                raise TypeError(f"Expect sub-classes of Metric. Got {type(obj)}.")
        else:
            raise TypeError(f"Expect metrics to be list or Metric. Got {type(metrics)}.")

    def _check_amp_level_arg(self, optimizer, amp_level):
        """Check mixed-precision argument rules."""
        raise NotImplementedError

    def _check_for_graph_cell(self, kwargs):
        """Check network rules of GraphCell."""
        raise NotImplementedError

    def _build_boost_network(self, *kwargs):
        """Build boost network."""
        raise NotImplementedError

    def _check_reuse_dataset(self, dataset):
        """Check if dataset is being used by other models under the data sink mode."""
        if not hasattr(dataset, '__model_hash__'):
            dataset.__model_hash__ = hash(self)
        if dataset.__model_hash__ != hash(self):
            raise RuntimeError("The dataset object has already been used by another model "
                               "via model.train(...), please create a new dataset.")

    def run(self, tgt_columns=None):
        """
        Evaluating function entry.

        Args:
            tgt_columns (Optional[list[str], str]): Target label column names for loss function.
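
        Example:
            A hedged sketch; ``'label'`` stands in for whichever target column
            the bound dataset actually provides::

                >>> evaluator.run(tgt_columns='label')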

        """
        args_dict = vars(self)
        run_context = RunContext(args_dict)
        self.callback_manager.evaluate_begin(run_context)
        self.clear_metrics()
        _ = self._run(tgt_columns)
        self.callback_manager.evaluate_end(run_context)
        self.earlystop = getattr(run_context, 'earlystop', False)

    def _run(self, tgt_columns=None):
        """Evaluating process for non-data sinking mode. The data would be passed to network directly."""
        self.network.set_train(False)
        with tqdm(total=self.total) as progress:
            progress.set_description('Evaluate')
            for data in self.eval_dataset.create_dict_iterator():
                inputs, tgts = self._data_process(data, tgt_columns)
                outputs = self.eval_func(**inputs)
                self._update_metrics(outputs, *tgts)
                progress.update(1)

        metrics_result, metrics_names, metrics_values = self._get_metrics()

        print(f'Evaluate Score: {metrics_result}')
        return metrics_result, metrics_names, metrics_values

    def _run_ds_sink(self):
        """Evaluating process for data sinking mode."""
        raise NotImplementedError

    def _get_metrics(self):
        """Get all metrics values."""
        metrics = {}
        metrics_names = []
        metrics_values = []
        for metric in self.metrics:
            key = metric.get_metric_name()
            metrics_names.append(key)
            metrics[key] = metric.eval()
            metrics_values.append(metrics[key])
        return metrics, metrics_names, metrics_values

    def clear_metrics(self):
        """Clear metrics values."""
        for metric in self.metrics:
            metric.clear()

    def _update_metrics(self, outputs, *tgts):
        """Update metrics values."""
        logits = outputs[0] if isinstance(outputs, tuple) else outputs
        for metric in self.metrics:
            metric.update(logits, *tgts)
        return True

    def _data_process(self, data, tgt_columns):
        """Process data match the network construct"""
        sig = signature(self.network.construct)
        net_args = sig.parameters

        inputs = {}
        for arg in net_args:
            if arg in ('self', 'kwargs'):
                continue
            if arg not in data:
                if net_args[arg].default is not None:
                    raise ValueError(f'network inputs need `{arg}`, but it was not found in dataset columns.')
            else:
                inputs[arg] = data.get(arg)

        # process target dataset.
        tgts = ()
        tgt_columns = self._prepare_tgt_columns(tgt_columns)
        for tgt_column in tgt_columns:
            if tgt_column in data:
                tgts = tgts + (data.pop(tgt_column),)
            else:
                raise ValueError(f'`{tgt_column}` not found in dataset, please check dataset column names.')
        return inputs, tgts

    def _prepare_tgt_columns(self, tgt_columns):
        """Check and prepare target columns for training."""
        out_columns = []
        if tgt_columns is None:
            logger.warning("In the process of training model, tgt_column can not be None.")
            return []
        if isinstance(tgt_columns, str):
            out_columns.append(tgt_columns)
        elif isinstance(tgt_columns, list):
            for col in tgt_columns:
                if isinstance(col, str):
                    out_columns.append(col)
                else:
                    raise TypeError(f"Expect str of tgt_column. Got {type(col)}")
        else:
            raise TypeError(f"Expect tgt_columns to be list or str. Got {type(tgt_columns)}.")
        return out_columns
